package com.chatplus.application.aiprocessor.util;

import cn.hutool.extra.spring.SpringUtil;
import com.chatplus.application.aiprocessor.constant.AIConstants;
import com.chatplus.application.common.constant.GroupCacheNames;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.common.util.CacheGroupUtils;
import com.chatplus.application.common.util.IdGenerator;
import com.chatplus.application.common.util.MessageTokenUtil;
import com.chatplus.application.domain.dto.AdminConfigDto;
import com.chatplus.application.domain.dto.ws.WsChatMessage;
import com.chatplus.application.domain.entity.chat.ChatHistoryEntity;
import com.chatplus.application.domain.entity.chat.ChatItemEntity;
import com.chatplus.application.domain.request.ws.ChatWebSocketRequest;
import com.chatplus.application.enumeration.MessageTypeEnum;
import com.chatplus.application.service.chat.ChatHistoryService;
import com.chatplus.application.service.chat.ChatItemService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.WebSocketSession;

/**
 * Chat session utilities: persisting chat items / chat history records and sending messages
 * to the client over a Spring WebSocket session.
 */
public final class ChatWebSocketUtil {

    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(ChatWebSocketUtil.class);

    /** History row type for the user's prompt. */
    private static final String HISTORY_TYPE_PROMPT = "prompt";
    /** History row type for the model's reply. */
    private static final String HISTORY_TYPE_REPLY = "reply";
    /** Maximum number of characters (Unicode code points) kept in a chat item title. */
    private static final int TITLE_MAX_LENGTH = 30;
    /** Suffix appended to a title that was truncated. */
    private static final String TITLE_ELLIPSIS = "...";

    private ChatWebSocketUtil() {
    }

    /**
     * Persists one prompt/reply exchange.
     * <p>
     * Creates the chat item for {@code request.getChatId()} if it does not exist yet, then saves
     * one "prompt" and one "reply" history row that share a freshly generated item id. That id is
     * finally stored back on the request via {@code setHistoryItemId}.
     *
     * @param request the chat request carrying chat id, user, model, prompt and reply
     */
    public static void saveChatHistory(ChatWebSocketRequest request) {
        // Ensure the chat item exists before history rows are attached to it.
        saveChatItemRecord(request);

        String modelName = request.getChatModel().getValue();
        String prompt = request.getCurrentPrompt();
        String reply = request.getReply();
        Long promptTokens = MessageTokenUtil.countMessageTokens(prompt, modelName);
        // Use the reply token count already present on the request, if any; otherwise count locally.
        Long replyTokens = request.getReplyToken() != null
                ? request.getReplyToken()
                : MessageTokenUtil.countMessageTokens(reply, modelName);

        // The prompt row and the reply row share one item id.
        Long itemId = IdGenerator.generateLongId();
        saveChatHistory(request, itemId, HISTORY_TYPE_PROMPT, promptTokens, prompt);
        saveChatHistory(request, itemId, HISTORY_TYPE_REPLY, replyTokens, reply);
        request.setHistoryItemId(itemId);
    }

    /**
     * Saves the chat item (the conversation entry) for the request's chat id, unless one already exists.
     * The title is derived from the current prompt.
     *
     * @param chatRequest the chat request
     */
    private static void saveChatItemRecord(ChatWebSocketRequest chatRequest) {
        String chatId = chatRequest.getChatId();
        ChatItemService chatItemService = SpringUtil.getBean(ChatItemService.class);
        // Skip if an item is already stored for this chat id.
        // NOTE(review): check-then-insert is not atomic; two concurrent first messages for the same
        // chat could both insert unless the table enforces uniqueness on chat id — confirm.
        if (chatItemService.getChatItemByChatId(chatId) != null) {
            return;
        }
        ChatItemEntity chatItemEntity = new ChatItemEntity();
        chatItemEntity.setChatId(chatId);
        chatItemEntity.setModelId(chatRequest.getModelId());
        chatItemEntity.setRoleId(chatRequest.getRoleId());
        chatItemEntity.setUserId(chatRequest.getUserId());
        chatItemEntity.setTitle(buildTitle(chatRequest.getCurrentPrompt()));
        chatItemService.save(chatItemEntity);
    }

    /**
     * Builds a chat item title from a prompt. A prompt longer than 30 code points is cut to its
     * first 30 code points and suffixed with {@code "..."}. Cutting on code-point boundaries
     * guarantees a surrogate pair (e.g. an emoji) is never split in half.
     *
     * @param prompt the prompt; {@code null} yields an empty title
     * @return the title
     */
    private static String buildTitle(String prompt) {
        if (prompt == null) {
            return "";
        }
        if (prompt.codePointCount(0, prompt.length()) <= TITLE_MAX_LENGTH) {
            return prompt;
        }
        int end = prompt.offsetByCodePoints(0, TITLE_MAX_LENGTH);
        return prompt.substring(0, end) + TITLE_ELLIPSIS;
    }

    /**
     * Saves one chat history row.
     *
     * @param chatRequest the chat request (supplies chat id, user id, role id, use-context flag)
     * @param itemId      id shared by the prompt row and reply row of one exchange
     * @param type        {@code "prompt"} or {@code "reply"}
     * @param tokens      token count of {@code content}
     * @param content     the prompt or reply text
     */
    private static void saveChatHistory(ChatWebSocketRequest chatRequest, Long itemId, String type, Long tokens, String content) {
        ChatHistoryService chatHistoryService = SpringUtil.getBean(ChatHistoryService.class);
        Long userId = chatRequest.getUserId();
        String chatId = chatRequest.getChatId();
        ChatHistoryEntity chatHistoryEntity = new ChatHistoryEntity();
        chatHistoryEntity.setChatId(chatId);
        chatHistoryEntity.setItemId(itemId);
        chatHistoryEntity.setType(type);
        chatHistoryEntity.setContent(content);
        chatHistoryEntity.setUserId(userId);
        chatHistoryEntity.setRoleId(chatRequest.getRoleId());
        chatHistoryEntity.setUseContext(chatRequest.getUseContext());
        if (HISTORY_TYPE_PROMPT.equals(type)) {
            chatHistoryEntity.setTokens(tokens);
        } else {
            // Reply rows store a running total: the tokens of the reply row returned by
            // getLastReplyHistoryByUserId (presumably the latest by id) plus this reply's tokens.
            // Null counts are treated as 0 instead of failing with an unboxing NPE.
            ChatHistoryEntity lastReply = chatHistoryService.getLastReplyHistoryByUserId(userId, chatId);
            long previousTokens = lastReply == null ? 0L : nullToZero(lastReply.getTokens());
            chatHistoryEntity.setTokens(previousTokens + nullToZero(tokens));
        }
        chatHistoryService.save(chatHistoryEntity);
    }

    /** Unboxes a nullable token count, mapping {@code null} to 0. */
    private static long nullToZero(Long value) {
        return value == null ? 0L : value;
    }

    /**
     * Sends one stream frame to the WebSocket session identified by {@code sessionId}.
     * <p>
     * A message equal to {@code MessageTypeEnum.WS_START}/{@code WS_END}'s value is sent as the
     * corresponding start/end marker; anything else is sent as a {@code WS_MIDDLE} content chunk.
     * Because markers are signalled in-band, a content chunk whose text equals a marker value is
     * sent as that marker. If the session is missing or closed, the message is logged and dropped;
     * send failures are logged and not rethrown.
     * <p>
     * NOTE(review): Spring's {@code WebSocketSession#sendMessage} does not support concurrent sends
     * on one session; if several threads can reply on the same session, consider wrapping it in
     * {@code ConcurrentWebSocketSessionDecorator} — confirm how {@code WebSocketManager} stores sessions.
     *
     * @param sessionId session id
     * @param message   message text, or a start/end marker value
     */
    public static void replyStreamMessage(String sessionId, String message) {
        WebSocketSession session = WebSocketManager.get(sessionId);
        if (session == null || !session.isOpen()) {
            LOGGER.message("session会话已经关闭").context("message", message).error();
            return;
        }
        try {
            session.sendMessage(new BinaryMessage(toFrame(message).toByte()));
        } catch (Exception e) {
            LOGGER.message("发送消息异常").exception(e).error();
        }
    }

    /** Maps a raw stream message to its frame: start marker, end marker, or content chunk. */
    private static WsChatMessage toFrame(String message) {
        if (MessageTypeEnum.WS_END.getValue().equals(message)) {
            return new WsChatMessage(MessageTypeEnum.WS_END, null);
        }
        if (MessageTypeEnum.WS_START.getValue().equals(message)) {
            return new WsChatMessage(MessageTypeEnum.WS_START, null);
        }
        return new WsChatMessage(MessageTypeEnum.WS_MIDDLE, message);
    }

    /**
     * Sends a complete message: start marker, the message as one chunk, then end marker.
     *
     * @param sessionId session id
     * @param message   message text
     */
    public static void replyFullMessage(String sessionId, String message) {
        replyStreamMessage(sessionId, MessageTypeEnum.WS_START.getValue());
        replyStreamMessage(sessionId, message);
        replyStreamMessage(sessionId, MessageTypeEnum.WS_END.getValue());
    }

    /**
     * Replies with the fixed error message {@code AIConstants.ERROR_MSG}, followed by the
     * WeChat card image if one is configured.
     *
     * @param sessionId session id
     */
    public static void replyErrorMessage(String sessionId) {
        replyErrorMessage(sessionId, AIConstants.ERROR_MSG);
    }

    /**
     * Replies with a custom error message, followed by the WeChat card image if one is configured.
     *
     * @param sessionId session id
     * @param errorMsg  the custom error message
     */
    public static void replyErrorMessage(String sessionId, String errorMsg) {
        replyFullMessage(sessionId, errorMsg);
        replyWechatCard(sessionId);
    }

    /**
     * Sends the configured WeChat card URL as a markdown image ({@code ![](url)}). Does nothing when
     * the system AI setting is not cached or the card URL is blank.
     *
     * @param sessionId session id
     */
    private static void replyWechatCard(String sessionId) {
        AdminConfigDto adminConfigDto = CacheGroupUtils.get(GroupCacheNames.SYS_AI_SETTING, AIConstants.SYSTEM_CONFIG_REDIS_KEY);
        if (adminConfigDto == null) {
            return;
        }
        String wxCard = adminConfigDto.getWechatCardUrl();
        // isBlank (not isEmpty): a whitespace-only URL would otherwise render a broken image.
        if (StringUtils.isBlank(wxCard)) {
            return;
        }
        replyFullMessage(sessionId, String.format("![](%s)", wxCard));
    }

    /**
     * Sends a complete message, optionally without the leading start marker.
     *
     * @param sessionId session id
     * @param message   message text
     * @param haveStart whether a start marker must be sent first. This depends on whether
     *                  StreamEventSourceListener.onOpen connected successfully (and already sent one).
     *                  For errors raised inside onEvent, {@code false} is normally passed, unless a
     *                  new output block is wanted.
     */
    public static void replyFullMessage(String sessionId, String message, boolean haveStart) {
        if (haveStart) {
            replyFullMessage(sessionId, message);
        } else {
            replyStreamMessage(sessionId, message);
            replyStreamMessage(sessionId, MessageTypeEnum.WS_END.getValue());
        }
    }
}
